{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.Conv2d(3, 16, 3, padding=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 2))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.Conv2d(3, 16, 3, stride=2, padding=(1,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 4), stride=(1, 1), padding=(1, 1))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.Conv2d(3, 16, (3,4), padding=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 4), stride=(1, 1), padding=(1, 4))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.Conv2d(3, 16, (3,4), padding=(1,4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 4), stride=(3, 3), padding=(1, 2))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.Conv2d(3, 16, (3,4), stride=(3,3), padding=(1,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 4), stride=(3, 3), padding=(1, 2))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.Conv2d(3, 16, (3,4), stride=3, padding=(1,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.Conv2d(3, 16, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.Conv2d(3, 16, 3, stride=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_pool = nn.MaxPool2d(3, stride=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.FloatTensor(3,5,5).random_(0, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[8., 0., 2., 9., 8.],\n",
       "         [0., 5., 4., 9., 0.],\n",
       "         [1., 3., 7., 4., 9.],\n",
       "         [5., 9., 4., 4., 6.],\n",
       "         [7., 8., 9., 2., 9.]],\n",
       "\n",
       "        [[6., 6., 1., 5., 9.],\n",
       "         [3., 8., 3., 6., 1.],\n",
       "         [9., 5., 4., 4., 2.],\n",
       "         [5., 6., 5., 6., 5.],\n",
       "         [7., 6., 2., 5., 6.]],\n",
       "\n",
       "        [[6., 6., 7., 4., 9.],\n",
       "         [4., 3., 2., 4., 4.],\n",
       "         [3., 0., 5., 7., 6.],\n",
       "         [6., 3., 1., 0., 4.],\n",
       "         [3., 3., 5., 0., 7.]]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[8., 9., 9.],\n",
       "         [9., 9., 9.],\n",
       "         [9., 9., 9.]],\n",
       "\n",
       "        [[9., 8., 9.],\n",
       "         [9., 8., 6.],\n",
       "         [9., 6., 6.]],\n",
       "\n",
       "        [[7., 7., 9.],\n",
       "         [6., 7., 7.],\n",
       "         [6., 7., 7.]]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max_pool(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_pool = nn.AvgPool2d(3, stride=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[3.3333, 4.7778, 5.7778],\n",
       "         [4.2222, 5.4444, 5.2222],\n",
       "         [5.8889, 5.5556, 6.0000]],\n",
       "\n",
       "        [[5.0000, 4.6667, 3.8889],\n",
       "         [5.3333, 5.2222, 4.0000],\n",
       "         [5.4444, 4.7778, 4.3333]],\n",
       "\n",
       "        [[4.0000, 4.2222, 5.3333],\n",
       "         [3.0000, 2.7778, 3.6667],\n",
       "         [3.2222, 2.6667, 3.8889]]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "avg_pool(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ToTensor()"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.ToTensor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Resize(size=1, interpolation=PIL.Image.BILINEAR)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Resize(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Resize(size=(1, 1), interpolation=PIL.Image.BILINEAR)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Resize((1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pad(padding=1, fill=0, padding_mode=constant)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Pad(1, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pad(padding=(1, 2), fill=1, padding_mode=constant)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Pad((1, 2), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pad(padding=(1, 2, 2, 3), fill=0, padding_mode=reflect)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Pad((1, 2, 2, 3), padding_mode='reflect')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Normalize(mean=(0.5,), std=(0.5,))"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Normalize((0.5,),(0.5,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    CenterCrop(size=(10, 10))\n",
       "    ToTensor()\n",
       ")"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Compose([\n",
    "     transforms.CenterCrop(10),\n",
    "     transforms.ToTensor(),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CenterCrop(size=(10, 10))"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.CenterCrop(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomCrop(size=(10, 10), padding=None)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.RandomCrop(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomCrop(size=(10, 20), padding=None)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.RandomCrop((10, 20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomHorizontalFlip(p=0.3)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.RandomHorizontalFlip(p=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomVerticalFlip(p=0.3)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.RandomVerticalFlip(p=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ColorJitter(brightness=[0.75, 1.25], contrast=[0.75, 1.25], saturation=[0.75, 1.25], hue=[-0.25, 0.25])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.ColorJitter(0.25, 0.25, 0.25, 0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomRotation(degrees=(-10, 10), resample=False, expand=False)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.RandomRotation(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Compose(\n",
       "    RandomRotation(degrees=(-10, 10), resample=False, expand=False)\n",
       "    ToTensor()\n",
       ")"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Compose([\n",
    "     transforms.RandomRotation(10),\n",
    "     transforms.ToTensor(),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training pipeline: random augmentation steps, then tensor conversion and\n",
    "# per-channel normalization (mean 0.5, std 0.5 for each of the three channels).\n",
    "augmentation_steps = [\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomRotation(20),\n",
    "]\n",
    "tensor_steps = [\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "]\n",
    "transformations = transforms.Compose(augmentation_steps + tensor_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "train_data = datasets.CIFAR10('CIFAR10', train=True, download=True, transform=transformations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Evaluation images must not be randomly flipped or rotated: keep only the\n",
    "# deterministic tensor conversion and the same normalization used for training.\n",
    "test_transformations = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "])\n",
    "test_data = datasets.CIFAR10('CIFAR10', train=False, download=True, transform=test_transformations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, 10000)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data), len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data.sampler import SubsetRandomSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_size = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_size = len(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = list(range(training_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle with a dedicated, seeded generator so the train/validation split is the\n",
    "# same on every run and the global NumPy random state is left untouched.\n",
    "np.random.RandomState(42).shuffle(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "index_split = int(np.floor(training_size * validation_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_indices, training_indices = indices[:index_split], indices[index_split:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_sample = SubsetRandomSampler(training_indices)\n",
    "validation_sample = SubsetRandomSampler(validation_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data.dataloader import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training and validation batches are drawn from disjoint index subsets of train_data.\n",
    "train_loader = DataLoader(train_data, batch_size=batch_size, sampler=training_sample)\n",
    "valid_loader = DataLoader(train_data, batch_size=batch_size, sampler=validation_sample)\n",
    "# The test loader must iterate the held-out test split, not the training images.\n",
    "test_loader = DataLoader(test_data, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    \"\"\"Small convolutional classifier: three conv/ReLU/max-pool stages, then two linear layers.\n",
    "\n",
    "    The first linear layer expects 64 * 4 * 4 features per sample and the\n",
    "    last one produces 10 output scores.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Channel widths grow 3 -> 16 -> 32 -> 64; every conv uses a 3x3 kernel with padding 1.\n",
    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)\n",
    "        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)\n",
    "        # One shared 2x2, stride-2 pooling layer is reused after each convolution.\n",
    "        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.linear1 = nn.Linear(in_features=64 * 4 * 4, out_features=512)\n",
    "        self.linear2 = nn.Linear(in_features=512, out_features=10)\n",
    "        self.dropout = nn.Dropout(0.3)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the network on `x` and return the raw output of the final linear layer.\"\"\"\n",
    "        features = x\n",
    "        for conv in (self.conv1, self.conv2, self.conv3):\n",
    "            features = self.pool(F.relu(conv(features)))\n",
    "        # Collapse everything but the feature width expected by linear1.\n",
    "        flat = self.dropout(features.view(-1, 64 * 4 * 4))\n",
    "        hidden = self.dropout(F.relu(self.linear1(flat)))\n",
    "        return self.linear2(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CNN()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (linear1): Linear(in_features=1024, out_features=512, bias=True)\n",
       "  (linear2): Linear(in_features=512, out_features=10, bias=True)\n",
       "  (dropout): Dropout(p=0.3, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device.type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.SGD(model.parameters(), lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(preds, y):\n",
    "    \"\"\"Return the fraction of rows in `preds` whose arg-max class equals `y`.\n",
    "\n",
    "    preds: per-class scores, one row per sample (classes along dim 1).\n",
    "    y: ground-truth class indices, reshapeable to the shape of the predictions.\n",
    "    Returns a 0-dim float tensor in [0, 1].\n",
    "    \"\"\"\n",
    "    # torch.argmax returns only the indices, so there is nothing to unpack.\n",
    "    pred = torch.argmax(preds, dim=1)\n",
    "    # Compare against the `y` argument rather than a global `target`.\n",
    "    correct = pred.eq(y.view_as(pred))\n",
    "    # Cast before averaging so the result is a true fraction, not an integer division.\n",
    "    return correct.float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Epoch: 01 | Train Loss: 2.057 | Val. Loss: 1.768 |\n",
      "| Epoch: 02 | Train Loss: 1.660 | Val. Loss: 1.516 |\n",
      "| Epoch: 03 | Train Loss: 1.507 | Val. Loss: 1.424 |\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, n_epochs+1):\n",
    "    # Running sums of per-sample loss; divided by the subset sizes below.\n",
    "    train_loss = 0.0\n",
    "    valid_loss = 0.0\n",
    "\n",
    "    # Training pass: dropout active, weights updated after every batch.\n",
    "    model.train()\n",
    "    for data, target in train_loader:\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # The criterion (default mean reduction) returns a batch average, so scale\n",
    "        # by the batch size to accumulate a per-sample sum.\n",
    "        train_loss += loss.item()*data.size(0)\n",
    "\n",
    "    # Validation pass: dropout disabled, and no autograd graph is recorded\n",
    "    # because no backward pass happens here.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for data, target in valid_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            loss = criterion(output, target)\n",
    "            valid_loss += loss.item()*data.size(0)\n",
    "\n",
    "    # Samplers hold the index subsets, so their lengths are the sample counts.\n",
    "    train_loss = train_loss/len(train_loader.sampler)\n",
    "    valid_loss = valid_loss/len(valid_loader.sampler)\n",
    "\n",
    "    print(f'| Epoch: {epoch:02} | Train Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f} |')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), \"cifar10.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
